{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding new objects to Nengo\n",
    "\n",
    "You can add new objects\n",
    "to the Nengo reference simulator.\n",
    "Doing so starts with a frontend class;\n",
    "some objects also need changes to the simulator itself.\n",
    "In this example, we'll add a new neuron type to Nengo:\n",
    "a rectified linear neuron.\n",
    "\n",
    "The `RectifiedLinear` class is what you will use\n",
    "in model scripts to denote that a particular ensemble\n",
    "should be simulated using a rectified linear neuron\n",
    "instead of one of the existing neuron types (e.g., `LIF`).\n",
    "\n",
    "Normally, frontend classes like this live\n",
    "in a module in the root `nengo` directory,\n",
    "such as `nengo/neurons.py` or `nengo/synapses.py`.\n",
    "Look at these files for examples of how to make your own.\n",
    "Because we're making a neuron type,\n",
    "we'll use `nengo.neurons.LIF` as a template\n",
    "for `RectifiedLinear`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import nengo\n",
    "from nengo.utils.ensemble import tuning_curves\n",
    "\n",
    "\n",
    "# Neuron types must subclass `nengo.neurons.NeuronType`\n",
    "class RectifiedLinear(nengo.neurons.NeuronType):\n",
    "    \"\"\"A rectified linear neuron model.\"\"\"\n",
    "\n",
    "    # We don't need any additional parameters here;\n",
    "    # gain and bias are sufficient. But, if we wanted\n",
    "    # more parameters, we could accept them by creating\n",
    "    # an __init__ method.\n",
    "\n",
    "    def gain_bias(self, max_rates, intercepts):\n",
    "        \"\"\"Return gain and bias given maximum firing rate and x-intercept.\"\"\"\n",
    "        # Chosen so the rate is 0 at x = intercept and max_rate at x = 1\n",
    "        gain = max_rates / (1 - intercepts)\n",
    "        bias = -intercepts * gain\n",
    "        return gain, bias\n",
    "\n",
    "    def step(self, dt, J, output):\n",
    "        \"\"\"Compute rates in Hz for input current (incl. bias).\"\"\"\n",
    "        # Write the rates into `output` in place rather than returning them\n",
    "        output[...] = np.maximum(0.0, J)"
   ]
  },
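  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the `gain_bias` formulas,\n",
    "the sketch below repeats them in plain NumPy, without the simulator.\n",
    "It assumes the input current is `J = gain * x + bias`\n",
    "for a neuron whose encoder is +1;\n",
    "the sample values and the `rate` helper are illustrative only.\n",
    "Each neuron should be silent at its own intercept\n",
    "and reach its maximum rate at `x = 1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative values only; any max_rates > 0 and intercepts < 1 would do\n",
    "max_rates = np.array([80.0, 100.0])\n",
    "intercepts = np.array([-0.5, 0.2])\n",
    "\n",
    "# Same formulas as in RectifiedLinear.gain_bias\n",
    "gain = max_rates / (1 - intercepts)\n",
    "bias = -intercepts * gain\n",
    "\n",
    "\n",
    "def rate(x):\n",
    "    \"\"\"Rectified linear rate for input x, assuming an encoder of +1.\"\"\"\n",
    "    return np.maximum(0.0, gain * x + bias)\n",
    "\n",
    "\n",
    "# Each neuron is silent at its own intercept...\n",
    "assert np.allclose(rate(intercepts), 0.0)\n",
    "# ...and reaches its maximum rate at x = 1\n",
    "assert np.allclose(rate(1.0), max_rates)"
   ]
  },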
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use `RectifiedLinear` like any other neuron type\n",
    "without making modifications to the reference simulator.\n",
    "However, other objects, including more complicated neuron types,\n",
    "may require changes to the reference simulator.\n",
    "\n",
    "## Tuning curves\n",
    "\n",
    "We can build a small network just to see the tuning curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network()\n",
    "with model:\n",
    "    encoders = np.tile([[1], [-1]], (4, 1))\n",
    "    intercepts = np.linspace(-0.8, 0.8, 8)\n",
    "    intercepts *= encoders[:, 0]\n",
    "    A = nengo.Ensemble(\n",
    "        8,\n",
    "        dimensions=1,\n",
    "        intercepts=intercepts,\n",
    "        neuron_type=RectifiedLinear(),\n",
    "        max_rates=nengo.dists.Uniform(80, 100),\n",
    "        encoders=encoders,\n",
    "    )\n",
    "with nengo.Simulator(model) as sim:\n",
    "    eval_points, activities = tuning_curves(A, sim)\n",
    "plt.figure()\n",
    "plt.plot(eval_points, activities, lw=2)\n",
    "plt.xlabel(\"Input signal\")\n",
    "plt.ylabel(\"Firing rate (Hz)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2D Representation example\n",
    "\n",
    "Below is the same model as in the 2d_representation example,\n",
    "but using `RectifiedLinear` neurons instead of `nengo.LIF`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network(label=\"2D Representation\", seed=10)\n",
    "with model:\n",
    "    neurons = nengo.Ensemble(100, dimensions=2, neuron_type=RectifiedLinear())\n",
    "    sin = nengo.Node(output=np.sin)\n",
    "    cos = nengo.Node(output=np.cos)\n",
    "    nengo.Connection(sin, neurons[0])\n",
    "    nengo.Connection(cos, neurons[1])\n",
    "    sin_probe = nengo.Probe(sin, \"output\")\n",
    "    cos_probe = nengo.Probe(cos, \"output\")\n",
    "    neurons_probe = nengo.Probe(neurons, \"decoded_output\", synapse=0.01)\n",
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(5)"
   ]
  },
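  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put a number on how well the ensemble represents the input,\n",
    "one option is the root-mean-square error between the decoded output\n",
    "and the two input signals.\n",
    "This is only a rough sketch: the decoded probe is filtered\n",
    "with a 0.01 s synapse while the input probes are not,\n",
    "so part of the error comes from the filter's lag.\n",
    "The names `target`, `decoded`, and `rmse` are introduced here\n",
    "for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the two input signals into shape (n_steps, 2) to match the probe\n",
    "target = np.hstack([sim.data[sin_probe], sim.data[cos_probe]])\n",
    "decoded = sim.data[neurons_probe]\n",
    "\n",
    "# Root-mean-square error per dimension over the whole run\n",
    "rmse = np.sqrt(np.mean((decoded - target) ** 2, axis=0))\n",
    "print(\"RMSE (sine, cosine):\", rmse)"
   ]
  },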
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.plot(sim.trange(), sim.data[neurons_probe], label=\"Decoded output\")\n",
    "plt.plot(sim.trange(), sim.data[sin_probe], \"r\", label=\"Sine\")\n",
    "plt.plot(sim.trange(), sim.data[cos_probe], \"k\", label=\"Cosine\")\n",
    "plt.xlabel(\"Time (s)\")\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
